{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# QPyTorch Functionality Overview\n",
    "\n",
    "## Introduction\n",
    "In this notebook, we provide an overview of the major features of QPyTorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import qtorch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quantization\n",
    "\n",
    "QPyTorch supports three different number formats: fixed point, block floating point, and floating point.\n",
    "\n",
    "QPyTorch provides quantization functions that quantize PyTorch tensors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qtorch.quant import fixed_point_quantize, block_quantize, float_quantize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Full Precision: tensor([0.1241, 0.3602, 0.7104, 0.8344, 0.0211])\n",
      "Low Precision: tensor([0.1250, 0.3750, 0.7500, 0.8750, 0.0195])\n"
     ]
    }
   ],
   "source": [
    "full_precision_tensor = torch.rand(5)\n",
    "print(\"Full Precision: {}\".format(full_precision_tensor))\n",
    "low_precision_tensor = float_quantize(full_precision_tensor, exp=5, man=2, rounding=\"nearest\")\n",
    "print(\"Low Precision: {}\".format(low_precision_tensor))"
   ]
  },
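  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see where the low precision values above come from, here is a worked reading of the format. It assumes that `exp` and `man` are the numbers of exponent and mantissa bits, and it ignores subnormals and the limits of the exponent range. With `man=2`, a normal number has the form\n",
    "\n",
    "$$x = \\pm\\left(1 + \\frac{m}{2^{2}}\\right) 2^{e}, \\quad m \\in \\{0, 1, 2, 3\\}$$\n",
    "\n",
    "so adjacent representable values in $[2^{e}, 2^{e+1})$ are $2^{e-2}$ apart. The entry $0.3602$ lies in $[2^{-2}, 2^{-1})$, where the spacing is $2^{-4} = 0.0625$ and the two neighbors are $0.3125$ and $0.3750$. Nearest rounding then gives\n",
    "\n",
    "$$|0.3602 - 0.3750| = 0.0148 < |0.3602 - 0.3125| = 0.0477 \\;\\Rightarrow\\; Q(0.3602) = 0.3750$$\n",
    "\n",
    "which matches the printed output."
   ]
  },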
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "QPyTorch supports both nearest rounding and stochastic rounding. "
   ]
  },
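  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reminder of what the two modes mean, the standard definitions are sketched below; they are not taken from the QPyTorch source. Let $x_{lo} < x_{hi}$ be the two representable values adjacent to $x$, with $x_{lo} \\le x \\le x_{hi}$. Nearest rounding picks whichever of the two is closer to $x$. Stochastic rounding picks\n",
    "\n",
    "$$Q_{s}(x) = \\begin{cases} x_{hi} & \\text{with probability } \\dfrac{x - x_{lo}}{x_{hi} - x_{lo}} \\\\ x_{lo} & \\text{otherwise} \\end{cases}$$\n",
    "\n",
    "which makes it unbiased:\n",
    "\n",
    "$$\\mathbb{E}[Q_{s}(x)] = x_{lo} + (x_{hi} - x_{lo}) \\cdot \\frac{x - x_{lo}}{x_{hi} - x_{lo}} = x$$\n",
    "\n",
    "Because the outcome is random, stochastic rounding can coincide with nearest rounding on a small tensor in any single run."
   ]
  },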
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Nearest: tensor([0.1250, 0.3750, 0.7500, 0.8750, 0.0195])\n",
      "Stochastic: tensor([0.1250, 0.3750, 0.7500, 0.8750, 0.0195])\n"
     ]
    }
   ],
   "source": [
    "nearest_rounded = float_quantize(full_precision_tensor, exp=5, man=2, rounding=\"nearest\")\n",
    "stochastic_rounded = float_quantize(full_precision_tensor, exp=5, man=2, rounding=\"stochastic\")\n",
    "print(\"Nearest: {}\".format(nearest_rounded))\n",
    "print(\"Stochastic: {}\".format(stochastic_rounded))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Autograd\n",
    "\n",
    "QPyTorch offers a PyTorch nn.Module wrapper to integrate quantization into automatic differentiation. A Quantizer module can use different low-precision number formats for forward and backward propagation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First define number formats used in forward and backward quantization\n",
    "from qtorch import FixedPoint, FloatingPoint\n",
    "forward_num = FixedPoint(wl=4, fl=2)\n",
    "backward_num = FloatingPoint(exp=5, man=2)\n",
    "\n",
    "# Create a quantizer\n",
    "from qtorch.quant import Quantizer\n",
    "Q = Quantizer(forward_number=forward_num, backward_number=backward_num,\n",
    "              forward_rounding=\"nearest\", backward_rounding=\"stochastic\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the QPyTorch Quantizer just like any other nn.Module\n",
    "from torch.nn import Module, Linear\n",
    "class LinearLP(Module):\n",
    "    \"\"\"\n",
    "    a low precision Logistic Regression model\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super(LinearLP, self).__init__()\n",
    "        self.W = Linear(5, 1)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        out = self.W(x)\n",
    "        out = Q(out)\n",
    "        return out\n",
    "    \n",
    "lp_model = LinearLP()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Low Precision Output: tensor([-0.2500], grad_fn=<RoundingBackward>)\n"
     ]
    }
   ],
   "source": [
    "# forward low precision model, get low precision output\n",
    "fake_input = torch.rand(5)\n",
    "lp_output = lp_model(fake_input)\n",
    "print(\"Low Precision Output: {}\".format(lp_output))"
   ]
  },
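  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The forward format used above is `FixedPoint(wl=4, fl=2)`. Assuming the usual reading of these parameters (a word length of `wl` bits including the sign, of which `fl` are fractional bits; this is an interpretation, not a quotation from the QPyTorch documentation), nearest-rounding quantization with saturation can be written as\n",
    "\n",
    "$$Q(x) = \\operatorname{clip}\\left(2^{-\\mathrm{fl}} \\left\\lfloor 2^{\\mathrm{fl}} x + \\tfrac{1}{2} \\right\\rfloor,\\; -2^{\\mathrm{wl}-\\mathrm{fl}-1},\\; 2^{\\mathrm{wl}-\\mathrm{fl}-1} - 2^{-\\mathrm{fl}}\\right)$$\n",
    "\n",
    "For `wl=4, fl=2` this gives a step of $2^{-2} = 0.25$ and a range of $[-2, 1.75]$, so every forward output is a multiple of $0.25$. That is consistent with the printed value $-0.2500$."
   ]
  },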
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# backward propagation is quantized automatically\n",
    "from torch import sigmoid\n",
    "from torch.nn import BCELoss\n",
    "lp_model.zero_grad()\n",
    "criterion = BCELoss()\n",
    "label = torch.Tensor([0])\n",
    "loss = criterion(sigmoid(lp_output), label)\n",
    "loss.backward()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Low Precision Optimization\n",
    "\n",
    "### Weight and Gradient Quantization\n",
    "In the previous example, the forward and backward signals are quantized into low precision. However, if we optimize our model using gradient descent, the weights and gradients are not necessarily low precision. QPyTorch offers a low precision wrapper for PyTorch optimizers that abstracts the quantization of weights, gradients, and the momentum velocity vectors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim import SGD\n",
    "from qtorch.optim import OptimLP\n",
    "\n",
    "optimizer = SGD(lp_model.parameters(), momentum=0.9, lr=0.1) # use your favorite optimizer\n",
    "# define custom quantization functions for different numbers\n",
    "weight_quant = lambda x : float_quantize(x, exp=5, man=2, rounding=\"nearest\")\n",
    "gradient_quant = lambda x : float_quantize(x, exp=5, man=2, rounding=\"nearest\")\n",
    "momentum_quant = lambda x : float_quantize(x, exp=6, man=9, rounding=\"nearest\")\n",
    "# turn your optimizer into a low precision optimizer\n",
    "optimizer = OptimLP(optimizer, \n",
    "                    weight_quant=weight_quant, \n",
    "                    grad_quant=gradient_quant, \n",
    "                    momentum_quant=momentum_quant)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weight before optimizer stepping: \n",
      "tensor([[-0.1850,  0.1250, -0.1007, -0.0862,  0.3034]])\n",
      "Gradient before optimizer stepping: \n",
      "tensor([[0.1051, 0.2755, 0.0375, 0.1643, 0.1883]])\n",
      "\n",
      "Weight after optimizer stepping: \n",
      "tensor([[-0.1875,  0.0938, -0.1094, -0.1094,  0.3125]])\n",
      "Gradient after optimizer stepping: \n",
      "tensor([[0.1094, 0.2500, 0.0391, 0.1562, 0.1875]])\n"
     ]
    }
   ],
   "source": [
    "print(\"Weight before optimizer stepping: \\n{}\".format(lp_model.W.weight.data))\n",
    "print(\"Gradient before optimizer stepping: \\n{}\\n\".format(lp_model.W.weight.grad))\n",
    "optimizer.step()\n",
    "print(\"Weight after optimizer stepping: \\n{}\".format(lp_model.W.weight.data))\n",
    "print(\"Gradient after optimizer stepping: \\n{}\".format(lp_model.W.weight.grad))\n",
    "optimizer.zero_grad()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gradient Accumulator\n",
    "One popular practice in low precision training is to use a higher precision gradient accumulator. The gradients, after being multiplied by the learning rate and modified by the momentum terms, are added to the high precision gradient accumulator. On the next iteration of forward and backward propagation, the weights are re-quantized from the gradient accumulator, so expensive computations are still done in low precision. \n",
    "\n",
    "QPyTorch integrates this process into the low precision optimizer. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Low Precision Output: tensor([0.2500], grad_fn=<RoundingBackward>)\n"
     ]
    }
   ],
   "source": [
    "# Let's quickly repeat the above example\n",
    "lp_model = LinearLP()\n",
    "fake_input = torch.rand(5)\n",
    "lp_output = lp_model(fake_input)\n",
    "print(\"Low Precision Output: {}\".format(lp_output))\n",
    "lp_model.zero_grad()\n",
    "criterion = BCELoss()\n",
    "label = torch.Tensor([0])\n",
    "loss = criterion(sigmoid(lp_output), label)\n",
    "loss.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define a low precision optimizer with gradient accumulators\n",
    "optimizer = SGD(lp_model.parameters(), momentum=0, lr=0.1)\n",
    "weight_quant = lambda x : float_quantize(x, exp=5, man=2, rounding=\"nearest\")\n",
    "gradient_quant = lambda x : float_quantize(x, exp=5, man=2, rounding=\"nearest\")\n",
    "acc_quant = lambda x : float_quantize(x, exp=6, man=9, rounding=\"nearest\") # use higher precision for accumulator\n",
    "optimizer = OptimLP(optimizer, \n",
    "                    weight_quant=weight_quant, \n",
    "                    grad_quant=gradient_quant, \n",
    "                    momentum_quant=momentum_quant,  # defined in the earlier optimizer cell\n",
    "                    acc_quant=acc_quant)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Weight before optimizer stepping: \n",
      "tensor([[ 0.2943, -0.1296, -0.4130, -0.2599, -0.4059]])\n",
      "\n",
      "after stepping, high precision accumulator : \n",
      "tensor([[ 0.2817, -0.1309, -0.4253, -0.2603, -0.4185]])\n",
      "after stepping, low precision weight : \n",
      "tensor([[ 0.3125, -0.1250, -0.4375, -0.2500, -0.4375]])\n"
     ]
    }
   ],
   "source": [
    "print(\"Weight before optimizer stepping: \\n{}\\n\".format(lp_model.W.weight.data))\n",
    "optimizer.step()\n",
    "print(\"after stepping, high precision accumulator : \\n{}\".format(optimizer.weight_acc[lp_model.W.weight]))\n",
    "print(\"after stepping, low precision weight : \\n{}\".format(lp_model.W.weight.data))\n",
    "optimizer.zero_grad()"
   ]
  },
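  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to write the accumulator scheme described above, for plain SGD without momentum, is the following. It is a sketch of the idea only; the exact order of operations inside `OptimLP` is not shown in this notebook.\n",
    "\n",
    "$$a_{t+1} = Q_{acc}\\left(a_{t} - \\eta \\, Q_{g}(g_{t})\\right), \\qquad w_{t+1} = Q_{w}(a_{t+1})$$\n",
    "\n",
    "Here $a$ is the high precision accumulator, $w$ the low precision weight, $g$ the gradient, $\\eta$ the learning rate, and $Q_{acc}$, $Q_{g}$, $Q_{w}$ the three quantizers. The printed values are consistent with the second relation. The first accumulator entry, $0.2817$, lies between the neighbors $0.2500$ and $0.3125$ (spacing $2^{-4}$ for `man=2`), and\n",
    "\n",
    "$$|0.2817 - 0.3125| = 0.0308 < |0.2817 - 0.2500| = 0.0317 \\;\\Rightarrow\\; Q_{w}(0.2817) = 0.3125$$\n",
    "\n",
    "which is the first entry of the low precision weight."
   ]
  },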
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## High-level Helper\n",
    "\n",
    "QPyTorch also provides a useful helper that automatically turns a predefined PyTorch model into a low-precision one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qtorch.auto_low import sequential_lower\n",
    "class LinearFP(Module):\n",
    "    \"\"\"\n",
    "    a full precision Logistic Regression model\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super(LinearFP, self).__init__()\n",
    "        self.W = Linear(5, 1)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        out = self.W(x)\n",
    "        return out\n",
    "    \n",
    "fp_model = LinearFP()\n",
    "\n",
    "forward_num = FixedPoint(wl=4, fl=2)\n",
    "backward_num = FloatingPoint(exp=5, man=2)\n",
    "lp_model = sequential_lower(fp_model, layer_types=['linear'],\n",
    "                            forward_number=forward_num, backward_number=backward_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Full Precision Model: \n",
      "LinearFP(\n",
      "  (W): Linear(in_features=5, out_features=1, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "print(\"Full Precision Model: \")\n",
    "print(fp_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Low Precision Model: \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "LinearFP(\n",
       "  (W): Sequential(\n",
       "    (0): Linear(in_features=5, out_features=1, bias=True)\n",
       "    (1): Quantizer()\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(\"Low Precision Model: \")\n",
    "lp_model"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
